import os
import urllib
import traceback
import time
import sys
import numpy as np
import cv2
from rknn.api import RKNN
import threading
from queue import Queue
from spirems import Subscriber, Publisher, cvimg2sms, sms2cvimg, def_msg, QoS, BaseNode, get_extra_args
from spirems.mod_helper import download_model
import argparse
from typing import Union
import platform
from copy import copy
import json
from spirecv.algorithm.utils import calc_fov, calc_los_pos


# Model from https://github.com/airockchip/rknn_model_zoo


class RKNN_model_container():
    """Thin wrapper around an ``rknn.api.RKNN`` runtime loaded from a ``.rknn`` file.

    The process exits (``sys.exit`` with the toolkit's return code) if the model
    cannot be loaded or the runtime cannot be initialised.
    """

    def __init__(self, model_path, target=None, device_id=None) -> None:
        """Load ``model_path`` and initialise the inference runtime.

        Args:
            model_path: path to a pre-converted ``.rknn`` model file.
            target: target platform passed to ``init_runtime``; ``None`` calls
                ``init_runtime()`` with no arguments.
            device_id: device identifier passed to ``init_runtime`` together with
                ``target`` (ignored when ``target`` is ``None``).
        """
        rknn = RKNN()

        # Direct Load RKNN Model. A failure here used to be ignored and only
        # surfaced later as a confusing init/inference error.
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            print('Load RKNN model failed')
            sys.exit(ret)

        print('--> Init runtime environment')
        if target is None:
            ret = rknn.init_runtime()
        else:
            ret = rknn.init_runtime(target=target, device_id=device_id)
        if ret != 0:
            print('Init runtime environment failed')
            # sys.exit rather than the site-provided exit(), which is not
            # guaranteed to exist (e.g. python -S or frozen apps).
            sys.exit(ret)
        print('done')

        self.rknn = rknn

    # No __del__ finalizer: callers must invoke release() explicitly.

    def run(self, inputs):
        """Run one inference.

        Args:
            inputs: a single input array, or a list/tuple of input arrays.

        Returns:
            The result of ``RKNN.inference``, or ``[]`` if the runtime has
            already been released.
        """
        if self.rknn is None:
            print("ERROR: rknn has been released")
            return []

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        return self.rknn.inference(inputs=inputs)

    def release(self):
        """Release the underlying runtime. Safe to call more than once."""
        if self.rknn is not None:
            self.rknn.release()
            self.rknn = None


class Letter_Box_Info():
    """Record of one resize/letterbox transform, kept so results can be mapped back.

    Attributes (as filled in by ``COCO_test_helper``):
        origin_shape: ``im.shape[:2]`` of the image before the transform.
        new_shape: the requested target shape.
        w_ratio, h_ratio: scale factors applied along width and height.
        dw, dh: padding added on each side along width and height.
        pad_color: fill value used for the padding border.
    """

    def __init__(self, shape, new_shape, w_ratio, h_ratio, dw, dh, pad_color) -> None:
        # Geometry before / after the transform.
        self.origin_shape, self.new_shape = shape, new_shape
        # Per-axis scale factors.
        self.w_ratio, self.h_ratio = w_ratio, h_ratio
        # Per-side padding and its fill colour.
        self.dw, self.dh = dw, dh
        self.pad_color = pad_color


def coco_eval_with_json(anno_json, pred_json):
    """Evaluate COCO-format bbox detections against ground truth and print metrics.

    Args:
        anno_json: COCO ground-truth annotation file passed to ``COCO``.
        pred_json: detection results passed to ``COCO.loadRes`` (e.g. the file
            written by ``COCO_test_helper.export_to_json``).

    Returns:
        ``COCOeval.stats`` (the summary metrics array). The original returned
        ``None``; returning the array lets callers use the numbers directly.
    """
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    anno = COCO(anno_json)
    pred = anno.loadRes(pred_json)
    # Named coco_eval (not ``eval``) so the builtin is not shadowed.
    coco_eval = COCOeval(anno, pred, 'bbox')
    # coco_eval.params.useCats = 0
    # coco_eval.params.maxDets = list((100, 300, 1000))
    # a = np.array(list(range(50, 96, 1)))/100
    # coco_eval.params.iouThrs = a
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    stats = coco_eval.stats
    # pycocotools bbox summary layout: [0] AP@[.5:.95], [1] AP@.5, [2] AP@.75,
    # [3-5] AP small/medium/large, [6-8] AR@1/10/100, [9-11] AR small/medium/large.
    map_50_95, map_50 = stats[:2]

    print('map  --> ', map_50_95)
    print('map50--> ', map_50)
    print('map75--> ', stats[2])
    # stats[-2] / stats[-1] are recall for medium / large objects; they were
    # previously mislabelled as "map85" / "map95".
    print('ar_medium--> ', stats[-2])
    print('ar_large--> ', stats[-1])
    return stats


class COCO_test_helper():
    """Letterbox/resize pre-processing plus helpers to undo it and collect COCO records.

    When letterbox tracking is enabled, every ``letter_box`` / ``direct_resize``
    call appends a ``Letter_Box_Info`` to ``letter_box_info_list``; the
    ``get_real_*`` and ``add_single_record`` helpers use the most recent entry.
    """

    def __init__(self, enable_letter_box = False) -> None:
        """
        Args:
            enable_letter_box: when truthy, record each transform so boxes and
                masks can be mapped back to original-image coordinates.
        """
        self.record_list = []
        # Normalised to bool so every method agrees on the flag (the original
        # mixed ``is True`` and ``== True`` checks, which disagree for e.g. 1).
        # The misspelt attribute name is kept for backward compatibility.
        self.enable_ltter_box = bool(enable_letter_box)
        self.letter_box_info_list = [] if self.enable_ltter_box else None

    @staticmethod
    def _split_padding(pad):
        """Split a per-side (possibly fractional) padding into the (before, after) pixel counts letter_box uses."""
        # The +/-0.1 nudge puts the extra pixel of an odd total padding after the image.
        return int(round(pad - 0.1)), int(round(pad + 0.1))

    def letter_box(self, im, new_shape, pad_color=(0,0,0), info_need=False):
        """Resize ``im`` keeping its aspect ratio, then pad it to ``new_shape``.

        Args:
            im: image array whose first two dims are (height, width).
            new_shape: target (height, width), or an int for a square target.
            pad_color: border value passed to ``cv2.copyMakeBorder``.
            info_need: if True, also return the scale ratio and per-side padding.

        Returns:
            The padded image, or ``(image, ratio, (dw, dh))`` when ``info_need`` is True.
        """
        shape = im.shape[:2]  # current shape [height, width]
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # One scale ratio for both axes so the aspect ratio is preserved.
        ratio = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

        # Compute padding
        new_unpad = int(round(shape[1] * ratio)), int(round(shape[0] * ratio))  # (w, h)
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
        dw /= 2  # divide padding into 2 sides
        dh /= 2

        if shape[::-1] != new_unpad:  # resize
            im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = self._split_padding(dh)
        left, right = self._split_padding(dw)
        im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=pad_color)  # add border

        if self.enable_ltter_box:
            self.letter_box_info_list.append(Letter_Box_Info(shape, new_shape, ratio, ratio, dw, dh, pad_color))
        if info_need is True:
            return im, ratio, (dw, dh)
        return im

    def direct_resize(self, im, new_shape, info_need=False):
        """Resize ``im`` to ``new_shape`` (height, width) without keeping the aspect ratio.

        ``info_need`` is accepted for signature parity with ``letter_box`` but is
        ignored: only the resized image is returned.
        """
        shape = im.shape[:2]
        h_ratio = new_shape[0] / shape[0]
        w_ratio = new_shape[1] / shape[1]
        if self.enable_ltter_box:
            self.letter_box_info_list.append(Letter_Box_Info(shape, new_shape, w_ratio, h_ratio, 0, 0, (0,0,0)))
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(im, (new_shape[1], new_shape[0]))

    def get_real_box(self, box, in_format='xyxy'):
        """Map boxes from transformed-image coordinates back to the original image.

        Args:
            box: float array of shape (N, 4) (rank fixed by the ``[:, k]`` indexing).
            in_format: only ``'xyxy'`` is transformed; other formats are
                returned as an unmodified copy.

        Returns:
            A copy of ``box`` (the input is not modified), unpadded, unscaled and
            clipped to the original image when tracking is enabled.
        """
        bbox = copy(box)
        if self.enable_ltter_box and in_format == 'xyxy':
            # unletter_box result
            info = self.letter_box_info_list[-1]
            orig_h, orig_w = info.origin_shape[0], info.origin_shape[1]
            # Columns 0/2 are x (width padding/scale), columns 1/3 are y.
            for col, pad, scale, limit in ((0, info.dw, info.w_ratio, orig_w),
                                           (1, info.dh, info.h_ratio, orig_h),
                                           (2, info.dw, info.w_ratio, orig_w),
                                           (3, info.dh, info.h_ratio, orig_h)):
                bbox[:, col] -= pad
                bbox[:, col] /= scale
                bbox[:, col] = np.clip(bbox[:, col], 0, limit)
        return bbox

    def get_real_seg(self, seg):
        """Undo the last transform on a stack of masks.

        Args:
            seg: mask stack of shape (N, H, W) in transformed-image space (rank
                implied by the slicing and ``transpose(1, 2, 0)`` below).

        Returns:
            uint8 0/1 masks of shape (N, orig_h, orig_w), or ``seg`` unchanged when
            no padding or resizing was applied.
        """
        info = self.letter_box_info_list[-1]
        origin_shape = info.origin_shape
        top, bottom = self._split_padding(info.dh)
        left, right = self._split_padding(info.dw)
        if top == bottom == left == right == 0 and origin_shape == info.new_shape:
            return seg
        # Crop exactly the border letter_box added. Explicit end indices (instead
        # of ``d:-d`` with a truncated d) stay correct for odd or zero padding and
        # handle padding on both axes.
        height, width = seg.shape[1], seg.shape[2]
        seg = seg[:, top:height - bottom, left:width - right]
        seg = np.where(seg, 1, 0).astype(np.uint8).transpose(1, 2, 0)
        seg = cv2.resize(seg, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_LINEAR)
        # cv2.resize drops the channel axis for single-channel input.
        if seg.ndim < 3:
            return seg[None, :, :]
        return seg.transpose(2, 0, 1)

    @staticmethod
    def _encode_mask(mask):
        """RLE-encode a 2-D binary mask with pycocotools; counts decoded to str for JSON."""
        from pycocotools.mask import encode
        rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]
        rle["counts"] = rle["counts"].decode("utf-8")
        return rle

    def add_single_record(self, image_id, category_id, bbox, score, in_format='xyxy', pred_masks = None):
        """Append one detection to ``record_list`` in COCO results format.

        Args:
            image_id: COCO image id.
            category_id: COCO category id.
            bbox: box coordinates in ``in_format`` layout, in transformed-image
                space. Not modified; a float copy is transformed.
            score: detection confidence.
            in_format: only ``'xyxy'`` is supported.
            pred_masks: optional 2-D binary mask, RLE-encoded into ``segmentation``.

        Raises:
            AssertionError: if ``in_format`` is not ``'xyxy'``.
        """
        if in_format != 'xyxy':
            # Raised explicitly (not via ``assert``) so the check survives python -O.
            raise AssertionError("now only support xyxy format, please add code to support others format")

        # Plain-float copy: the original mutated the caller's sequence in place,
        # and numpy float32 values make json.dump fail in export_to_json.
        bbox = [float(v) for v in bbox]
        if self.enable_ltter_box:
            # unletter_box result
            info = self.letter_box_info_list[-1]
            bbox[0] = (bbox[0] - info.dw) / info.w_ratio
            bbox[1] = (bbox[1] - info.dh) / info.h_ratio
            bbox[2] = (bbox[2] - info.dw) / info.w_ratio
            bbox[3] = (bbox[3] - info.dh) / info.h_ratio

        # change xyxy to xywh
        bbox[2] = bbox[2] - bbox[0]
        bbox[3] = bbox[3] - bbox[1]

        record = {"image_id": image_id,
                  "category_id": category_id,
                  "bbox": [round(x, 3) for x in bbox],
                  'score': round(float(score), 5),
                  }
        if pred_masks is not None:
            record['segmentation'] = self._encode_mask(pred_masks)
        self.record_list.append(record)

    def export_to_json(self, path):
        """Write all collected records to ``path`` as a JSON array."""
        with open(path, 'w') as f:
            json.dump(self.record_list, f)


class YOLOv11DetNode_Rknn(threading.Thread, BaseNode):
    """SpireMS node running YOLOv11 detection with an RKNN (or ONNX / PyTorch) backend.

    A subscriber callback decodes and queues incoming images; ``run`` (this
    thread) letterboxes, infers, post-processes and runs NMS, then publishes a
    ``spirecv_msgs::2DTargets`` result and re-publishes the image message with
    the result attached under the ``'spirecv_msgs::2DTargets'`` key.
    """

    def __init__(
        self,
        job_name: str,
        ip: str = '127.0.0.1',
        port: int = 9094,
        param_dict_or_file: Union[dict, str] = None,
        sms_shutdown: bool = True,
        **kwargs
    ):
        """Read parameters, load the model, create pub/sub endpoints and start the thread.

        Args:
            job_name: prefix for the default topic names ('/<job_name>/...').
            ip, port: SpireMS core address used for local topics.
            param_dict_or_file: parameter dict or parameter file path, forwarded to BaseNode.
            sms_shutdown: forwarded to BaseNode; 'True' / 'true' / '1' also count as True.
            **kwargs: extra parameters forwarded to BaseNode.
        """
        threading.Thread.__init__(self)
        sms_shutdown = sms_shutdown in ['True', 'true', '1', True]
        BaseNode.__init__(
            self,
            self.__class__.__name__,
            job_name,
            ip=ip,
            port=port,
            param_dict_or_file=param_dict_or_file,
            sms_shutdown=sms_shutdown,
            **kwargs
        )
        self.launch_next_emit = self.get_param("launch_next_emit", True)
        self.specified_input_topic = self.get_param("specified_input_topic", "")
        self.specified_output_topic = self.get_param("specified_output_topic", "")
        self.realtime_det = self.get_param("realtime_det", True)
        self.remote_ip = self.get_param("remote_ip", "127.0.0.1")
        self.remote_port = self.get_param("remote_port", 9094)
        self.confidence = self.get_param("confidence", 0.25)
        self.nms_thresh = self.get_param("nms_thresh", 0.45)
        # Network input size as (height, width) -- letter_box treats index 0 as height.
        self.imgsz = self.get_param("imgsz", [640, 640])
        self.dataset_name = self.get_param("dataset_name", "coco_detection")
        self.objs_in_meter = self.get_param("objs_in_meter", {"person": [-1, 1.8], "keyboard": [0.43, -1]})  # {category_name: [w, h], ...}
        self.model_path = self.get_param("model_path", "yolo11.rknn")
        self.target = self.get_param("target", "rk3588")
        self.device_id = self.get_param("device_id", "")
        self.use_shm = self.get_param("use_shm", -1)
        self.g_dataset_categories = self.get_param("/dataset_categories", {})
        self.params_help()

        # use_shm: 1 = force shared memory, -1 = auto (Linux only), anything else = off.
        self.b_use_shm = self.use_shm == 1 or (self.use_shm == -1 and platform.system() == 'Linux')

        if self.model_path.startswith("sms::"):
            self.local_model_path = download_model(self.__class__.__name__, self.model_path)
            assert self.local_model_path is not None
        else:
            self.local_model_path = self.model_path

        self.dataset_categories = self.g_dataset_categories[self.dataset_name]
        self.model, self.platform = self.setup_model()
        self.co_helper = COCO_test_helper(enable_letter_box=True)

        input_url = '/' + job_name + '/sensor/image_raw'
        if len(self.specified_input_topic) > 0:
            input_url = self.specified_input_topic

        output_url = '/' + job_name + '/detector/results'
        if len(self.specified_output_topic) > 0:
            output_url = self.specified_output_topic

        calib_url = '/' + job_name + '/sensor/calibration_info'

        self.job_queue = Queue()
        self.queue_pool.append(self.job_queue)

        # Fallback intrinsics until a calibration message arrives; calib_width /
        # calib_height <= 0 mean "uncalibrated" (see trans_det_results).
        self.calib_width, self.calib_height = -1, -1
        self.camera_matrix = np.array([712.12, 0, 645.23, 0, 705.87, 327.34, 0, 0, 1]).reshape(3, 3)
        self.distortion = np.array([0.0, 0.0, 0.0, 0.0, 0.0])

        self._image_reader = Subscriber(
            input_url, 'std_msgs::Null', self.image_callback,
            ip=ip, port=port, qos=QoS.Reliability
        )
        self._calibration_reader = Subscriber(
            calib_url, 'sensor_msgs::CameraCalibration', self.calibration_callback,
            ip=ip, port=port, qos=QoS.Reliability
        )
        self._result_writer = Publisher(
            output_url, 'spirecv_msgs::2DTargets',
            ip=self.remote_ip, port=self.remote_port, qos=QoS.Reliability
        )
        self._show_writer = Publisher(
            '/' + job_name + '/detector/image_results', 'memory_msgs::RawImage' if self.b_use_shm else 'sensor_msgs::CompressedImage',
            ip=ip, port=port
        )
        if self.launch_next_emit:
            self._next_writer = Publisher(
                '/' + job_name + '/launch_next', 'std_msgs::Boolean',
                ip=ip, port=port, qos=QoS.Reliability
            )

        self.start()

    def trans_det_results(self, boxes, classes, scores, h, w, camera_matrix, calib_wh, objs_in_meter, roi=None):
        """Convert post-processed detections into a ``spirecv_msgs::2DTargets`` message.

        Args:
            boxes: (N, 4) xyxy boxes in letterboxed network-input pixels, or None.
            classes, scores: per-box class indices and confidences.
            h, w: height and width of the full source image.
            camera_matrix: 3x3 intrinsics passed to calc_fov / calc_los_pos.
            calib_wh: [width, height] of the calibration; non-positive = uncalibrated.
            objs_in_meter: {category_name: [w, h]} physical sizes for calc_los_pos.
            roi: optional [x1, y1, x2, y2] crop of the source image the detector ran on.

        Returns:
            The message dict. Each target's ``bbox`` is [x, y, w, h] in source-image
            pixels and ``cxy`` is the box centre normalised by (width, height).
        """
        sms_results = def_msg('spirecv_msgs::2DTargets')

        sms_results["file_name"] = ""
        sms_results["height"] = h
        sms_results["width"] = w
        has_calib = False
        if calib_wh[0] > 0 and calib_wh[1] > 0:
            sms_results["fov_x"], sms_results["fov_y"] = calc_fov(camera_matrix, calib_wh)
            has_calib = True

        sms_results["targets"] = []
        if roi is not None:
            sms_results["rois"] = [[roi[0], roi[1], roi[2] - roi[0], roi[3] - roi[1]]]

        if boxes is not None:
            boxes = self.co_helper.get_real_box(boxes)
            for box, cls, score in zip(boxes, classes, scores):
                name = self.dataset_categories[int(cls)].strip()
                ann = dict()
                ann["category_name"] = name.replace(' ', '_').lower()
                ann["category_id"] = int(cls)
                ann["score"] = float(round(score, 3))
                bbox = [round(j, 3) for j in box.tolist()]
                # xyxy -> xywh
                bbox[2] = bbox[2] - bbox[0]
                bbox[3] = bbox[3] - bbox[1]
                if roi is not None:
                    # Shift from ROI-crop to source-image coordinates *before*
                    # computing cxy: cxy is normalised by the full image size, so
                    # it (and los/pos) used to be wrong whenever an ROI was used.
                    bbox[0] += roi[0]
                    bbox[1] += roi[1]
                ann["bbox"] = bbox
                ann["cxy"] = [
                    (bbox[0] + bbox[2] / 2.) / sms_results["width"],
                    (bbox[1] + bbox[3] / 2.) / sms_results["height"]
                ]
                if has_calib and name in objs_in_meter:
                    ann["los"], ann["pos"] = calc_los_pos(
                        camera_matrix, calib_wh,
                        ann["cxy"], [bbox[2], bbox[3]],
                        objs_in_meter[name]
                    )
                sms_results["targets"].append(ann)

        return sms_results

    def release(self):
        """Stop the node and kill every subscriber/publisher it created."""
        BaseNode.release(self)
        self._image_reader.kill()
        # The calibration subscriber was previously never killed.
        self._calibration_reader.kill()
        self._result_writer.kill()
        self._show_writer.kill()
        # _next_writer only exists when launch_next_emit is enabled; killing it
        # unconditionally raised AttributeError on shutdown.
        if self.launch_next_emit:
            self._next_writer.kill()

    def image_callback(self, msg):
        """Decode an incoming image and queue it for the worker thread.

        In realtime mode frames still waiting in the queue are dropped so only
        the newest one is processed.
        """
        if self.realtime_det:
            # Clear under the queue's own lock so it cannot interleave with the
            # worker thread's get().
            with self.job_queue.mutex:
                self.job_queue.queue.clear()
        img = sms2cvimg(msg)
        self.job_queue.put({'msg': msg, 'img': img})

    def calibration_callback(self, msg):
        """Store camera calibration (image size, intrinsics K, distortion D) from ``msg``."""
        self.calib_width = msg['width']
        self.calib_height = msg['height']

        self.camera_matrix = np.array(msg['K']).reshape(3, 3)
        self.distortion = np.array(msg['D'])

    def run(self):
        """Worker loop: detect on queued frames and publish results until stopped."""
        while self.is_running():
            msg_dict = self.job_queue.get(block=True)
            if msg_dict is None:
                break

            msg, img_src = msg_dict['msg'], msg_dict['img']

            if "rois" in msg and len(msg["rois"]) > 0:
                # First ROI, [x1, y1, x2, y2] in source-image pixels.
                roi = msg["rois"][0]
                img_infer = img_src[roi[1]: roi[3], roi[0]: roi[2], :]
            else:
                roi = None
                img_infer = img_src

            t1 = time.time()
            file_name = msg['file_name'] if 'file_name' in msg else ''

            # DO Object Detection
            # Only the newest letterbox record is ever read (get_real_box uses [-1]);
            # clear it each frame so the history does not grow without bound.
            self.co_helper.letter_box_info_list.clear()
            img = self.co_helper.letter_box(im=img_infer.copy(), new_shape=(self.imgsz[0], self.imgsz[1]), pad_color=(0, 0, 0))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # preprocess if not rknn model: NCHW float32 in [0, 1]
            if self.platform in ['pytorch', 'onnx']:
                input_data = img.transpose((2, 0, 1))
                input_data = input_data.reshape(1, *input_data.shape).astype(np.float32)
                input_data = input_data / 255.
            else:
                input_data = img
            outputs = self.model.run([input_data])
            boxes, classes, scores = self.post_process(outputs)

            res_msg = self.trans_det_results(boxes, classes, scores, img_src.shape[0], img_src.shape[1], self.camera_matrix, [self.calib_width, self.calib_height], self.objs_in_meter, roi)
            res_msg['file_name'] = file_name
            res_msg['dataset'] = self.dataset_name
            # Echo request bookkeeping fields back to the client.
            for key in ('client_id', 'img_id', 'img_total'):
                if key in msg:
                    res_msg[key] = msg[key]
            res_msg['time_used'] = time.time() - t1

            self._result_writer.publish(res_msg)

            if 'img_total' in msg and self.launch_next_emit:
                next_msg = def_msg('std_msgs::Boolean')
                next_msg['data'] = True
                self._next_writer.publish(next_msg)
                print('img_id', msg.get('img_id'))

            if self.b_use_shm:
                msg = self._show_writer.cvimg2sms_mem(img_src)
            msg['spirecv_msgs::2DTargets'] = res_msg
            self._show_writer.publish(msg)
            # END

        self.release()
        print('{} quit!'.format(self.__class__.__name__))

    def setup_model(self):
        """Create the inference backend matching the model file extension.

        Returns:
            (model, backend_name) with backend_name one of 'pytorch', 'rknn', 'onnx'.

        Raises:
            AssertionError: if the extension is not recognised.
        """
        model_path = self.local_model_path
        target = self.target
        # Empty (or missing) device_id -> None so RKNN uses its default device.
        device_id = self.device_id or None
        # Local named ``backend`` so the ``platform`` module is not shadowed.
        if model_path.endswith(('.pt', '.torchscript')):
            backend = 'pytorch'
            from py_utils.pytorch_executor import Torch_model_container
            model = Torch_model_container(model_path)
        elif model_path.endswith('.rknn'):
            backend = 'rknn'
            model = RKNN_model_container(model_path, target, device_id)
        elif model_path.endswith('.onnx'):
            backend = 'onnx'
            from py_utils.onnx_executor import ONNX_model_container
            model = ONNX_model_container(model_path)
        else:
            # Explicit raise (not ``assert``) so it survives python -O.
            raise AssertionError("{} is not rknn/pytorch/onnx model".format(model_path))
        print('Model-{} is {} model, starting val'.format(model_path, backend))
        return model, backend

    def post_process(self, input_data):
        """Decode raw model outputs into final detections.

        Args:
            input_data: list of output tensors. Assumed to be 3 branches, each
                starting with a box-distribution tensor followed by a class
                confidence tensor (any further per-branch outputs such as
                score_sum are ignored) -- TODO confirm against the exported model.

        Returns:
            (boxes, classes, scores) after confidence filtering and per-class NMS,
            boxes as (N, 4) xyxy in network-input pixels; (None, None, None) if
            nothing survives.
        """
        num_branches = 3
        pair_per_branch = len(input_data) // num_branches
        boxes, scores, classes_conf = [], [], []
        # The score_sum outputs are ignored in the Python path.
        for i in range(num_branches):
            base = pair_per_branch * i
            boxes.append(self.box_process(input_data[base]))
            classes_conf.append(input_data[base + 1])
            # No separate objectness output: constant 1 so score == class confidence.
            scores.append(np.ones_like(input_data[base + 1][:, :1, :, :], dtype=np.float32))

        def sp_flatten(_in):
            # (N, C, H, W) -> (N*H*W, C)
            ch = _in.shape[1]
            return _in.transpose(0, 2, 3, 1).reshape(-1, ch)

        boxes = np.concatenate([sp_flatten(v) for v in boxes])
        classes_conf = np.concatenate([sp_flatten(v) for v in classes_conf])
        scores = np.concatenate([sp_flatten(v) for v in scores])

        # filter according to threshold
        boxes, classes, scores = self.filter_boxes(boxes, scores, classes_conf)

        # per-class nms
        nboxes, nclasses, nscores = [], [], []
        for cls in set(classes):
            inds = np.where(classes == cls)
            b, c, s = boxes[inds], classes[inds], scores[inds]
            keep = self.nms_boxes(b, s)

            if len(keep) != 0:
                nboxes.append(b[keep])
                nclasses.append(c[keep])
                nscores.append(s[keep])

        if not nclasses and not nscores:
            return None, None, None

        return np.concatenate(nboxes), np.concatenate(nclasses), np.concatenate(nscores)

    def dfl(self, position):
        """Distribution Focal Loss decode: expected bin index per box side.

        Args:
            position: (n, 4 * bins, h, w) per-bin logits (layout fixed by the reshape).

        Returns:
            (n, 4, h, w) float32 array of softmax-weighted bin indices.
        """
        # Pure NumPy (previously torch): avoids needing PyTorch installed on the
        # board just for a softmax, and avoids a per-frame import.
        x = np.asarray(position, dtype=np.float32)
        n, c, h, w = x.shape
        p_num = 4
        mc = c // p_num
        y = x.reshape(n, p_num, mc, h, w)
        # Numerically stable softmax over the bin axis.
        y = np.exp(y - y.max(axis=2, keepdims=True))
        y /= y.sum(axis=2, keepdims=True)
        acc_metrix = np.arange(mc, dtype=np.float32).reshape(1, 1, mc, 1, 1)
        return (y * acc_metrix).sum(axis=2)

    def box_process(self, position):
        """Convert one branch's DFL output (n, 4*bins, gh, gw) to xyxy boxes in input pixels."""
        grid_h, grid_w = position.shape[2:4]
        col, row = np.meshgrid(np.arange(0, grid_w), np.arange(0, grid_h))
        col = col.reshape(1, 1, grid_h, grid_w)
        row = row.reshape(1, 1, grid_h, grid_w)
        grid = np.concatenate((col, row), axis=1)  # channel 0 = x (col), 1 = y (row)
        # imgsz is (height, width): x scales with width / grid_w, y with height / grid_h.
        stride = np.array([self.imgsz[1] // grid_w, self.imgsz[0] // grid_h]).reshape(1, 2, 1, 1)

        position = self.dfl(position)
        # DFL gives distances (left, top, right, bottom) from the cell centre.
        box_xy = grid + 0.5 - position[:, 0:2, :, :]
        box_xy2 = grid + 0.5 + position[:, 2:4, :, :]
        return np.concatenate((box_xy * stride, box_xy2 * stride), axis=1)

    def filter_boxes(self, boxes, box_confidences, box_class_probs):
        """Keep candidates whose best class score times box confidence is >= ``self.confidence``.

        Returns:
            (boxes, classes, scores) of the surviving candidates.
        """
        box_confidences = box_confidences.reshape(-1)

        class_max_score = np.max(box_class_probs, axis=-1)
        classes = np.argmax(box_class_probs, axis=-1)

        combined = class_max_score * box_confidences
        _class_pos = np.where(combined >= self.confidence)
        return boxes[_class_pos], classes[_class_pos], combined[_class_pos]

    def nms_boxes(self, boxes, scores):
        """Suppress non-maximal boxes.

        Args:
            boxes: (N, 4) xyxy boxes.
            scores: (N,) confidences.

        Returns:
            keep: ndarray, indices of boxes kept (highest score first), using
            IoU threshold ``self.nms_thresh``.
        """
        x = boxes[:, 0]
        y = boxes[:, 1]
        w = boxes[:, 2] - boxes[:, 0]
        h = boxes[:, 3] - boxes[:, 1]

        areas = w * h
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)

            xx1 = np.maximum(x[i], x[order[1:]])
            yy1 = np.maximum(y[i], y[order[1:]])
            xx2 = np.minimum(x[i] + w[i], x[order[1:]] + w[order[1:]])
            yy2 = np.minimum(y[i] + h[i], y[order[1:]] + h[order[1:]])

            w1 = np.maximum(0.0, xx2 - xx1 + 0.00001)
            h1 = np.maximum(0.0, yy2 - yy1 + 0.00001)
            inter = w1 * h1

            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            inds = np.where(ovr <= self.nms_thresh)[0]
            # +1 because ovr was computed against order[1:].
            order = order[inds + 1]
        return np.array(keep)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        type=str,
        default='default_params.json',
        help='SpireCV2 Config (.json)')
    parser.add_argument(
        '--job-name', '-j',
        type=str,
        default='live',
        help='SpireCV Job Name')
    parser.add_argument(
        '--ip',
        type=str,
        default='127.0.0.1',
        help='SpireMS Core IP')
    parser.add_argument(
        '--port',
        type=int,
        default=9094,
        help='SpireMS Core Port')
    args, unknown_args = parser.parse_known_args()
    if not os.path.isabs(args.config):
        # Relative configs are resolved under <spirecv-pro root>/params/spirecv2.
        current_path = os.path.abspath(__file__)
        root_name = 'spirecv-pro'
        root_idx = current_path.find(root_name)
        if root_idx >= 0:
            params_dir = os.path.join(current_path[:root_idx + len(root_name)], 'params', 'spirecv2')
            args.config = os.path.join(params_dir, args.config)
        else:
            # Not inside a spirecv-pro checkout: find() returned -1 and the old
            # code built a garbage path from current_path[:10]. Resolve against
            # the working directory instead.
            args.config = os.path.abspath(args.config)
    print("--config:", args.config)
    print("--job-name:", args.job_name)
    extra = get_extra_args(unknown_args)

    node = YOLOv11DetNode_Rknn(args.job_name, param_dict_or_file=args.config, ip=args.ip, port=args.port, **extra)
    node.join()
